Deep Learning: Mini Challenge

Yvo Keller, BSc Data Science, HS23

Introduction¶

My goal with this notebook is to show the direct relation between the paper (Show, Attend and Tell) and its implementation within this repository. I will explain the code as much as possible, but nonetheless recommend that you read the paper first.

Important: The code in this notebook (with the exception of the caption plotting and EDA) is for explanation purposes and NOT meant to be run in the notebook. For runnable code, please refer to the repo's README and directly to the referenced scripts.

Resources:

  • The original paper: Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
  • The implementation is based on AaronCCWong's implementation. I built on top of his work, trying some alternative techniques (like using pre-trained embeddings) and aligning it more closely with the original paper.

Data Loader¶

Preparing Dataset for the "Show, Attend and Tell" Model¶

In the implementation of the "Show, Attend and Tell" paper, a crucial step is preparing the dataset for training and evaluation. Below, I'll discuss the key components of the generate_json_data script and their significance in the context of the project.

Karpathy's Data Splits for Comparison to Original Paper¶

The script uses Karpathy's data splits, which are a standard way of dividing the dataset in image captioning tasks. This approach divides the dataset into training, validation, and test sets, allowing for a fair comparison with the results reported in the original "Show, Attend and Tell" paper.
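As a quick illustration, the `split` field of each entry in Karpathy's `dataset.json` determines which set an image belongs to. Counting images per split can be sketched like this (the dictionary below is a toy stand-in for the real file, which has the same `images` → `split` structure):

```python
from collections import Counter

# Toy stand-in for Karpathy's dataset.json; the real file has the same structure
split = {"images": [
    {"filename": "1.jpg", "split": "train"},
    {"filename": "2.jpg", "split": "train"},
    {"filename": "3.jpg", "split": "val"},
    {"filename": "4.jpg", "split": "test"},
]}

split_counts = Counter(img["split"] for img in split["images"])
print(split_counts)  # Counter({'train': 2, 'val': 1, 'test': 1})
```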

Role of Tokenization and Special Tokens¶

Tokenization is the process of splitting text into tokens (words or subwords), which are then mapped to numerical indices. This script uses pre-tokenized captions (as indicated by sentence['tokens']). Special tokens such as <start>, <eos>, <unk>, and <pad> are added to the vocabulary. These tokens serve specific purposes:

  • <start>: Marks the beginning of a caption.
  • <eos>: Signifies the end of a caption.
  • <unk>: Represents words that are not frequent enough in the dataset to be included in the vocabulary.
  • <pad>: Used for padding shorter captions to a uniform length.

Maximum Caption Length and Reasoning¶

The max_caption_length parameter, set to 25 by default, defines the maximum length of captions. This length is chosen based on the observation that most captions in the dataset are shorter than 25 tokens. Setting a maximum length is also beneficial for computational efficiency, as it ensures a uniform tensor size for batch processing during training. If a caption exceeds this length, it is truncated. As the <start> and <eos> tokens are added afterwards, this results in a final sequence length of 27 for all captions.
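To make the length bookkeeping concrete, here is a toy sketch of the encoding step (the vocabulary and caption are invented for illustration; the actual logic lives in process_caption_tokens in the script below):

```python
# Toy vocabulary; indices 0-3 are reserved for the special tokens
word_dict = {'<start>': 0, '<eos>': 1, '<unk>': 2, '<pad>': 3, 'a': 4, 'dog': 5, 'runs': 6}
max_length = 5  # stand-in for max_caption_length (25 in the real script)

tokens = ['a', 'dog', 'runs']
tokens = tokens[:max_length]  # truncate if too long
encoded = ([word_dict['<start>']]
           + [word_dict.get(t, word_dict['<unk>']) for t in tokens]  # <unk> for out-of-vocabulary words
           + [word_dict['<eos>']]
           + [word_dict['<pad>']] * (max_length - len(tokens)))      # pad to uniform length

print(encoded)       # [0, 4, 5, 6, 1, 3, 3]
print(len(encoded))  # max_length + 2 = 7 (27 in the real script)
```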

Preparing Data for Faster Loading During Training¶

The script processes and saves the data in JSON format, including image paths and corresponding tokenized captions. This preprocessing step speeds up the training process, as the data is already tokenized and split according to the required format. By loading preprocessed data, the model avoids the overhead of performing these operations during training iterations, leading to faster and more efficient training.

Implementation Details¶

  • The script uses argparse to allow easy customization of input parameters like data paths and thresholds for word frequency.
  • It reads the dataset split information from a JSON file and processes each image and its captions based on the specified splits.
  • The word dictionary is created based on the frequency of words in the training set, with a threshold defined by min_word_count (default 5).
  • If a word did not make it into the dictionary, the <unk> token is assigned.
  • Each caption is processed to include the special tokens and to ensure it adheres to the maximum length constraint.

Code¶

This script is located at generate_json_data.py. I wrote a second script that performs the same preparation steps, just using the BERT tokenizer; it can be found at generate_json_data_bert.py.

In [ ]:
import argparse, json
from collections import Counter


def generate_json_data(split_path, data_path, max_captions_per_image, min_word_count, max_caption_length):
    split = json.load(open(split_path, 'r'))
    word_count = Counter()

    train_img_paths = []
    train_caption_tokens = []
    validation_img_paths = []
    validation_caption_tokens = []
    test_img_paths = []
    test_caption_tokens = []

    max_length = 0
    for img in split['images']:
        caption_count = 0
        for sentence in img['sentences']:
            if caption_count < max_captions_per_image:
                caption_count += 1
            else:
                break

            try:  # support flickr8k dataset.json that doesn't have subfolders
                img['filepath']
                filepath_defined = True
            except KeyError:
                filepath_defined = False
            img_path = f"{data_path}/imgs{'/' + img['filepath'] if filepath_defined else ''}/{img['filename']}"

            if img['split'] == 'train':
                train_img_paths.append(img_path)
                train_caption_tokens.append(sentence['tokens'])
            elif img['split'] == 'val':
                validation_img_paths.append(img_path)
                validation_caption_tokens.append(sentence['tokens'])
            elif img['split'] == 'test':
                test_img_paths.append(img_path)
                test_caption_tokens.append(sentence['tokens'])
            max_length = max(max_length, len(sentence['tokens']))
            word_count.update(sentence['tokens'])

    words = [word for word in word_count.keys() if word_count[word] >= min_word_count]
    word_dict = {word: idx + 4 for idx, word in enumerate(words)}
    word_dict['<start>'] = 0
    word_dict['<eos>'] = 1
    word_dict['<unk>'] = 2
    word_dict['<pad>'] = 3

    with open(data_path + '/word_dict.json', 'w') as f:
        json.dump(word_dict, f)

    max_length = min(max_length, max_caption_length)
    train_captions = process_caption_tokens(train_caption_tokens, word_dict, max_length)
    validation_captions = process_caption_tokens(validation_caption_tokens, word_dict, max_length)
    test_captions = process_caption_tokens(test_caption_tokens, word_dict, max_length)

    with open(data_path + '/train_img_paths.json', 'w') as f:
        json.dump(train_img_paths, f)
    with open(data_path + '/val_img_paths.json', 'w') as f:
        json.dump(validation_img_paths, f)
    with open(data_path + '/train_captions.json', 'w') as f:
        json.dump(train_captions, f)
    with open(data_path + '/val_captions.json', 'w') as f:
        json.dump(validation_captions, f)
    with open(data_path + '/test_img_paths.json', 'w') as f:
        json.dump(test_img_paths, f)
    with open(data_path + '/test_captions.json', 'w') as f:
        json.dump(test_captions, f)

def process_caption_tokens(caption_tokens, word_dict, max_length):
    captions = []
    for tokens in caption_tokens:
        tokens = tokens[:max_length]
        token_idxs = [word_dict[token] if token in word_dict else word_dict['<unk>'] for token in tokens]
        captions.append([word_dict['<start>']] + token_idxs + [word_dict['<eos>']] + [word_dict['<pad>']] * (max_length - len(tokens)))

    return captions


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate json files')
    parser.add_argument('--split-path', type=str, default='data/coco/dataset.json')
    parser.add_argument('--data-path', type=str, default='data/coco')
    parser.add_argument('--max-captions', type=int, default=5,
                        help='maximum number of captions per image')
    parser.add_argument('--min-word-count', type=int, default=5,
                        help='minimum number of occurrences of a word to be included in word dictionary')
    parser.add_argument('--max-caption-length', type=int, default=25,
                        help='maximum number of tokens in a caption')
    args = parser.parse_args()

    generate_json_data(args.split_path, args.data_path, args.max_captions, args.min_word_count, args.max_caption_length)

ImageCaptionDataset¶

The next important step is providing the prepared data as a DataLoader. I have focused on optimizing the performance of the data pipeline, which was crucial for efficient training.

Utilizing Pinned Memory¶

In my DataLoader configuration, I've set pin_memory=True. This is a performance optimization in PyTorch that is particularly beneficial when using CUDA-enabled GPUs, or MPS (on M2) in my case. By enabling pinned memory, the DataLoader automatically places the fetched data Tensors in pinned memory, facilitating faster data transfer to the GPU.
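A minimal sketch of such a DataLoader configuration (with a dummy TensorDataset standing in for the real image/caption dataset, and pin_memory guarded so the sketch also runs on machines without a GPU):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy stand-in dataset: 8 "images" (3x224x224) and 8 padded "captions" (length 27)
dataset = TensorDataset(torch.randn(8, 3, 224, 224), torch.randint(0, 100, (8, 27)))

loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),  # pinned (page-locked) host memory speeds up host-to-device copies
)

for imgs, caps in loader:
    # with pinned memory, non_blocking=True lets the copy overlap with compute:
    # imgs = imgs.to('cuda', non_blocking=True)  # or 'mps' on Apple silicon
    print(imgs.shape, caps.shape)  # torch.Size([4, 3, 224, 224]) torch.Size([4, 27])
```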

Preprocessing Done Once Before Training¶

In the ImageCaptionDataset class, I handle all preprocessing of images and captions before training starts. By performing image loading, transformation, and caption preprocessing just once and storing them in memory, I avoid redundant processing for each epoch or batch during training. This strategy is particularly advantageous for large datasets, as it significantly reduces I/O operations and processing time during the training loops.

Key Features of the ImageCaptionDataset¶

  • Image Loading and Transformation: I load the image paths and captions from JSON files. The pil_loader function is used to load images, which are then transformed and converted to tensors.

  • BERT Embeddings Compatibility: The class can handle both standard and BERT embeddings, since I experiment with both newly trained embeddings and pretrained BERT embeddings in my model.

  • Fractional Dataset Utilization: I've added the capability to use only a fraction of the dataset, controlled by the fraction argument. This feature proved really useful for quick iterations and debugging.

  • Efficient Data Storage: I store preprocessed image and caption tensors in a list (self.data), which enables fast data retrieval during training.

  • Handling Multiple Captions: The class is designed to efficiently provide all 5 captions that exist for an image, in addition to the single caption currently being trained on. This is relevant for calculating the BLEU score in the validation and testing phases.

Code¶

This script is located at dataset.py.

In [ ]:
import json
import torch
from torch.utils.data import Dataset
from collections import defaultdict
from PIL import Image


def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class ImageCaptionDataset(Dataset):
    def __init__(self, transform, data_path, split_type='train', fraction=1.0, bert=False):
        super(ImageCaptionDataset, self).__init__()
        self.transform = transform
        
        # Load image paths and captions
        img_paths = json.load(open(data_path + f'/{split_type}_img_paths.json', 'r'))
        if bert:
            captions = json.load(open(data_path + f'/{split_type}_captions_bert.json', 'r'))
        else:
            captions = json.load(open(data_path + f'/{split_type}_captions.json', 'r'))

        # Reduce dataset size if fraction is not 1.0
        if fraction != 1.0:
            img_paths = img_paths[:int(len(img_paths) * fraction)]
            captions = captions[:int(len(captions) * fraction)]

        # Preprocess and store data
        self.data = []
        all_captions = defaultdict(list)  # Store all captions for each image path

        for img_path, caption in zip(img_paths, captions):
            img = pil_loader(img_path)
            if self.transform is not None:
                img = self.transform(img)
            self.data.append((torch.FloatTensor(img), torch.tensor(caption)))
            all_captions[img_path].append(caption)

        # Convert all_captions dictionary to a list matching the order of images
        self.all_captions = [all_captions[path] for path in img_paths]

    def __getitem__(self, index):
        img_tensor, caption_tensor = self.data[index]
        all_captions_tensor = torch.tensor(self.all_captions[index])
        return img_tensor, caption_tensor, all_captions_tensor

    def __len__(self):
        return len(self.data)

Explorative Data Analysis¶

In this section, I perform some explorative data analysis on the dataset that is later used for training. The dataset therefore already reflects all the preprocessing (truncation to the maximum caption length, minimum word count of 5, etc.) defined in the prior Data Loader section.

In [ ]:
import json
import spacy
import nltk
from nltk.corpus import stopwords
import pandas as pd
from wordcloud import WordCloud
import matplotlib.pyplot as plt

nltk.download('stopwords')

!python -m spacy download en_core_web_sm
en_core_web_sm = spacy.load("en_core_web_sm")
In [16]:
DATA_PATH = 'data/flickr8k/'
SPLIT_TYPES = ['train', 'val', 'test']

captions = []
for split_type in SPLIT_TYPES:
    captions.extend(json.load(open(DATA_PATH + f'/{split_type}_captions.json', 'r')))

print(f'Loaded {len(captions)} captions')

word_dict = json.load(open(DATA_PATH + '/word_dict.json', 'r'))
vocabulary_size = len(word_dict)

print(f'Vocabulary size: {vocabulary_size}')

# decode captions (build the inverse mapping once instead of searching word_dict for every token)
idx_to_word = {idx: word for word, idx in word_dict.items()}
decoded_captions = []
for caption in captions:
    decoded_caption = []
    for idx in caption:
        if idx == word_dict['<start>']:
            continue
        elif idx == word_dict['<eos>']:
            break
        else:
            decoded_caption.append(idx_to_word[idx])
    decoded_captions.append(decoded_caption)
Loaded 40000 captions
Vocabulary size: 2945
In [34]:
def generate_count_wordcloud(df:pd.DataFrame, top_n_words=30):
    en_stopwords = set(stopwords.words('english'))

    # create custom tokenizer that removes stopwords
    def spacy_tokenizer(text):
        tokens = en_core_web_sm(text)
        return [token for token in tokens if token.text not in en_stopwords]

    tokens = df.text.apply(spacy_tokenizer)
    lowercase_tokens = [token.lower_ for doc in tokens for token in doc]

    # create wordcloud
    wordcloud = WordCloud(
        width=800, height=400, background_color="white", max_words=top_n_words
    ).generate(" ".join(lowercase_tokens))

    # show wordcloud
    plt.figure(figsize=(12, 10))
    plt.imshow(wordcloud, interpolation="bilinear")
    plt.axis("off")
    plt.title('Word Cloud Flickr8k Image Captions')
    plt.show()

# create dataframe with captions
decoded_captions_joined = [' '.join(caption) for caption in decoded_captions]
df = pd.DataFrame({'text': decoded_captions_joined})

print(f'Sample decoded caption: {decoded_captions_joined[0]}')

# generate wordcloud
generate_count_wordcloud(df)
Sample decoded caption: a black dog is running after a white dog in the snow
In [33]:
# Distribution of caption lengths
caption_lengths = [len(caption) for caption in decoded_captions]
df_cl = pd.DataFrame({'caption_length': caption_lengths})

plt.figure()
plt.bar(df_cl.caption_length.value_counts().index, df_cl.caption_length.value_counts())
plt.title('Caption Length Distribution')
plt.xlabel('Caption Length')
plt.ylabel('Count')
plt.show()

Captions¶

As we can see, the average caption length is around 12 words. The word cloud shows that topics like dog, person, woman, boy, and little girl are very common in the Flickr8k dataset. The model should therefore be able to learn these topics very well.

Images¶

Flickr8k contains a total of 8,000 images with 5 captions each. The images come in different sizes, but are all scaled to 224x224 pixels before training.

Understanding Show, Attend and Tell¶

In this section, I focus on building an understanding of the concepts presented in the paper, and connecting it to their implementation in PyTorch.

The Encoder¶

Section 3.1.1 in the "Show, Attend and Tell" paper describes the encoder. The proposed model uses a CNN to extract a set of feature vectors, referred to as annotation vectors. This code implements that concept by using pre-trained models (VGG19, ResNet152, or DenseNet161), modifying them to exclude the final classification layers, and reshaping the output into a set of feature vectors (our annotation vectors). Each vector represents a different spatial part of the image, which is key for the attention mechanism in the later stages of the model.

Specifically for VGG19, which I will be using in combination with the Flickr8k dataset, this means keeping the CNN's layers up to just before the last pooling layer. This is done by taking the features part of the VGG19 model, which is a Sequential object, and removing its last layer, the final pooling layer.

VGG(
  (features): Sequential(
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True) <<<<< this is the last layer we use
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    ...
  )
)

Code¶

Here is the full Encoder class, which is located at encoder.py.

import torch.nn as nn
from torchvision.models import densenet161, resnet152, vgg19
from torchvision.models import VGG19_Weights

class Encoder(nn.Module):
    """
    Encoder network for image feature extraction, follows section 3.1.1 of the paper
    """
    def __init__(self, network='vgg19'):
        super(Encoder, self).__init__()
        self.network = network
        # Selection of pre-trained CNNs for feature extraction
        if network == 'resnet152':
            self.net = resnet152(pretrained=True)
            # Removing the final fully connected layers of ResNet152
            self.net = nn.Sequential(*list(self.net.children())[:-2])
            self.dim = 2048  # Dimension of feature vectors for ResNet152
        elif network == 'densenet161':
            self.net = densenet161(pretrained=True)
            # Removing the final layers of DenseNet161
            self.net = nn.Sequential(*list(list(self.net.children())[0])[:-1])
            self.dim = 1920  # Dimension of feature vectors for DenseNet161
        else:
            self.net = vgg19(weights=VGG19_Weights.DEFAULT)
            # Using features from VGG19, excluding the last pooling layer
            self.net = nn.Sequential(*list(self.net.features.children())[:-1])
            self.dim = 512  # Dimension of feature vectors for VGG19

            # Freezing the weights of the pre-trained CNN
            for params in self.net.parameters():
                params.requires_grad = False

    def forward(self, x):
        x = self.net(x)
        # These steps correspond to the extraction of annotation vectors (a = {a1,...,aL}) as described in Section 3.1.1 of the paper.
        # 1. Change the order from (BS, C, H, W) to (BS, H, W, C) in prep for reshaping
        x = x.permute(0, 2, 3, 1)
        # 2. Reshape to [BS, num_spatial_features, C], the -1 effectively flattens the height and width dimensions into a single dimension
        x = x.view(x.size(0), -1, x.size(-1))
        return x

The Decoder¶

Let's move forward to Section 3.1.2, explaining the Decoder.

On a high level, it works as follows:

  • The Decoder uses the extracted image features to initialize the states of the LSTM cell (by averaging them)
  • Then it employs an attention mechanism at each time step to focus on different parts of the image while generating the caption
  • The Decoder predicts one word of the caption at each time step, and its prediction is conditioned on the current LSTM state, the context vector from the attention mechanism, and the previous word.
  • Teacher forcing (a common technique in training sequence generation models where the ground truth word is fed as the next input instead of the model's prediction) is optionally used if tf=True.
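Teacher forcing can be illustrated in isolation with a bare nn.LSTMCell (toy sizes and randomly initialized layers; the real decoder additionally concatenates the attention context to the LSTM input):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embed_dim = 10, 8
embedding = nn.Embedding(vocab_size, embed_dim)
lstm = nn.LSTMCell(embed_dim, embed_dim)
to_vocab = nn.Linear(embed_dim, vocab_size)

caption = torch.tensor([[0, 4, 5, 1]])   # e.g. <start> a dog <eos>
h = c = torch.zeros(1, embed_dim)        # the real decoder initializes these from image features
use_tf = True                            # toggle teacher forcing

inp = caption[:, 0]  # first input is always <start>
for t in range(caption.size(1) - 1):
    h, c = lstm(embedding(inp), (h, c))
    logits = to_vocab(h)
    predicted = logits.argmax(dim=1)
    # teacher forcing feeds the ground-truth next word; otherwise feed the model's own prediction
    inp = caption[:, t + 1] if use_tf else predicted
```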

Code¶

Provided below is the full implementation of the Decoder. I will break the code down into smaller chunks to explore it in more detail, based on the paper's description of the Decoder.

The full code is located at decoder.py.

import torch
import torch.nn as nn
from attention import Attention


class Decoder(nn.Module):
    def __init__(self, vocabulary_size, encoder_dim, tf=False, ado=False, bert=False, attention=False):
        super(Decoder, self).__init__()
        self.use_tf = tf
        self.use_advanced_deep_output = ado
        self.use_bert = bert
        self.use_attention = attention

        # Initializing parameters
        self.encoder_dim = encoder_dim

        # Embeddings
        if bert == True:
            from transformers import BertModel, BertTokenizer
            self.bert_model = BertModel.from_pretrained('bert-base-uncased')
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.vocabulary_size = self.bert_model.config.vocab_size
            self.embedding_size = self.bert_model.config.hidden_size # 768

            # Embedding layer using BERT's embeddings
            self.embedding = self.bert_model.get_input_embeddings()

            # Freeze the BERT embeddings
            for param in self.embedding.parameters():
                param.requires_grad = False

            # Delete the BERT model to save memory (and checkpoint size)
            del self.bert_model
        else:
            self.vocabulary_size = vocabulary_size
            self.embedding_size = 512
            self.embedding = nn.Embedding(self.vocabulary_size, self.embedding_size)  # Embedding layer for input words

        # Initial LSTM cell state generators
        self.init_h = nn.Linear(encoder_dim, self.embedding_size)  # For hidden state
        self.init_c = nn.Linear(encoder_dim, self.embedding_size)  # For cell state
        self.tanh = nn.Tanh()

        # Attention mechanism related layers
        self.f_beta = nn.Linear(self.embedding_size, encoder_dim)  # Gating scalar in attention mechanism
        self.sigmoid = nn.Sigmoid()

        # Attention and LSTM components
        self.attention = Attention(encoder_dim, self.embedding_size)  # Attention network
        self.lstm = nn.LSTMCell(self.embedding_size + encoder_dim, self.embedding_size)  # LSTM cell

        # Deep output layers
        if self.use_advanced_deep_output:
            # Advanced DO: Layers for transforming LSTM state, context vector and embedding for DO-RNN
            hidden_dim, intermediate_dim = self.embedding_size, self.embedding_size
            self.f_h = nn.Linear(hidden_dim, intermediate_dim)  # Transforms LSTM hidden state
            self.f_z = nn.Linear(encoder_dim, intermediate_dim)  # Transforms context vector
            self.f_out = nn.Linear(intermediate_dim, self.vocabulary_size)  # Transforms the combined vector (sum of embedding, LSTM state, and context vector) to voc_size
            self.relu = nn.ReLU()  # Activation function
            self.dropout = nn.Dropout()

        # Simple DO: Layer for transforming LSTM state to vocabulary
        self.deep_output = nn.Linear(self.embedding_size, self.vocabulary_size)  # Maps LSTM outputs to vocabulary
        self.dropout = nn.Dropout()

    def forward(self, img_features, captions):
        # Forward pass of the decoder
        # NOTE: mps_device (the target torch.device) is expected to be defined globally, e.g. in the training script
        batch_size = img_features.size(0)

        # Initialize LSTM state
        h, c = self.get_init_lstm_state(img_features)

        # Teacher forcing setup
        max_timespan = max([len(caption) for caption in captions]) - 1

        if self.use_bert:
            start_token = torch.full((batch_size, 1), self.tokenizer.cls_token_id).long().to(mps_device)
        else:
            start_token = torch.zeros(batch_size, 1).long().to(mps_device)

        # Convert caption tokens to their embeddings
        if self.use_tf:
            caption_embedding = self.embedding(captions)
        else:
            previous_predicted_token_embedding = self.embedding(start_token)

        # Preparing to store predictions and attention weights
        preds = torch.zeros(batch_size, max_timespan, self.vocabulary_size).to(mps_device) # [BATCH_SIZE, TIME_STEPS, VOC_SIZE]
        alphas = torch.zeros(batch_size, max_timespan, img_features.size(1)).to(mps_device) # [BATCH_SIZE, TIME_STEPS, NUM_SPATIAL_FEATURES]

        # Generating captions
        for t in range(max_timespan):
            if self.use_attention:
                context, alpha = self.attention(img_features, h)  # Compute context vector via attention
                gate = self.sigmoid(self.f_beta(h))  # Gating scalar for context
                gated_context = gate * context  # Apply gate to context
            else:
                # If not using attention, treat all parts of the image equally
                alpha = torch.full((batch_size, img_features.size(1)), 1.0 / img_features.size(1), device=mps_device)  # Uniform attention
                context = img_features.mean(dim=1)  # Simply take the mean of the image features
                gated_context = context  # No gating applied

            # Prepare LSTM input
            if self.use_tf:
                lstm_input = torch.cat((caption_embedding[:, t], gated_context), dim=1)  # current embedding + context vector as input vector
            else:
                previous_predicted_token_embedding = previous_predicted_token_embedding.squeeze(1) if previous_predicted_token_embedding.dim() == 3 else previous_predicted_token_embedding
                lstm_input = torch.cat((previous_predicted_token_embedding, gated_context), dim=1)

            # LSTM forward pass
            h, c = self.lstm(lstm_input, (h, c))

            # Generate word prediction
            if self.use_advanced_deep_output:
                # NOTE: could explore alternative positions for dropout
                if self.use_tf:
                    output = self.advanced_deep_output(self.dropout(h), context, caption_embedding[:, t])
                else:
                    output = self.advanced_deep_output(self.dropout(h), context, previous_predicted_token_embedding)
            else:
                output = self.deep_output(self.dropout(h))

            preds[:, t] = output  # Store predictions
            alphas[:, t] = alpha  # Store attention weights

            # Prepare next input word
            if not self.use_tf:
                predicted_token_idxs = output.max(1)[1].reshape(batch_size, 1)  # output.max(1)[1] extracts the index of the token with the highest score
                previous_predicted_token_embedding = self.embedding(predicted_token_idxs)

        return preds, alphas

    def get_init_lstm_state(self, img_features):
        # Initializing LSTM state based on image features
        avg_features = img_features.mean(dim=1)

        c = self.init_c(avg_features)  # Cell state
        c = self.tanh(c)

        h = self.init_h(avg_features)  # Hidden state
        h = self.tanh(h)

        return h, c

    def advanced_deep_output(self, h, context, current_embedding):
        # Combine the LSTM state and context vector
        h_transformed = self.relu(self.f_h(h))
        z_transformed = self.relu(self.f_z(context))

        # Sum the transformed vectors with the embedding
        combined = h_transformed + z_transformed + current_embedding

        # Transform the combined vector & compute the output word probability
        return self.relu(self.f_out(combined))

Attention Mechanisms¶

At the heart of the Show, Attend and Tell model is the attention mechanism, which lets the model focus on different parts of the image while generating the caption. It is implemented as a separate module that the Decoder uses at each time step.

There are two main types of attention mechanisms: soft attention and hard attention. Soft attention is differentiable and allows for end-to-end training, while hard attention is non-differentiable and requires reinforcement learning to train. I want to break down both types of attention mechanisms theoretically and then show how they are implemented in the code.
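The Attention module itself (imported from attention.py in the Decoder above) is not reproduced in this notebook, so here is a minimal soft-attention sketch in the spirit of the paper's f_att; the layer names are illustrative and the repo's implementation may differ in detail:

```python
import torch
import torch.nn as nn

class SoftAttention(nn.Module):
    """Minimal soft attention: alpha_t = softmax(f_att(a, h_{t-1}))."""
    def __init__(self, encoder_dim, hidden_dim, attn_dim=512):
        super().__init__()
        self.U = nn.Linear(hidden_dim, attn_dim)   # transforms decoder state h
        self.W = nn.Linear(encoder_dim, attn_dim)  # transforms annotation vectors a_i
        self.v = nn.Linear(attn_dim, 1)            # scores each of the L locations

    def forward(self, img_features, h):
        # img_features: [BS, L, encoder_dim], h: [BS, hidden_dim]
        scores = self.v(torch.tanh(self.W(img_features) + self.U(h).unsqueeze(1))).squeeze(2)
        alpha = torch.softmax(scores, dim=1)                       # [BS, L], rows sum to 1
        context = (img_features * alpha.unsqueeze(2)).sum(dim=1)   # expected context vector [BS, encoder_dim]
        return context, alpha

attn = SoftAttention(encoder_dim=512, hidden_dim=512)
context, alpha = attn(torch.randn(2, 196, 512), torch.randn(2, 512))
print(context.shape, alpha.shape)  # torch.Size([2, 512]) torch.Size([2, 196])
```

Because the context is a weighted average over all locations, the whole computation stays differentiable, which is exactly why soft attention trains end-to-end with plain backpropagation.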

Stochastic Hard Attention¶

Stochastic "Hard" Attention is an approach where the model discretely chooses specific regions (or locations, i.e. one annotation vector) in an image to focus on at each step of generating a caption. This contrasts with "Soft" Attention, where the model considers all regions but with varying degrees of focus.

1. Attention Location Representation¶

$ s_{t, i} = 1 $ if the $ i $-th location is chosen at time $ t $, out of $ L $ total locations.

Significance: Represents the model's decision on where to focus in the image when generating the $ t $-th word in the caption as a one-hot vector.

2. Context Vector Computation¶

$$ \hat{\mathbf{z}}_t = \sum_i s_{t, i} \mathbf{a}_i $$

Significance: Computes the context vector as the feature vector of the selected image region. Only the chosen region contributes to the context at each step.

3. Attention as Multinoulli Distribution¶

$$ \tilde{s}_t \sim \operatorname{Multinoulli}_L(\{\alpha_i\}) $$

Significance: Models the attention decision as a random variable, following a Multinoulli distribution. The attention weights $ \{\alpha_i\} $ determine the probability of focusing on each region.
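Sampling a hard-attention location and forming the corresponding context vector can be sketched with torch.multinomial (toy sizes; this illustrates the formulas above and is not part of the repo's code, which trains with soft attention):

```python
import torch

torch.manual_seed(0)
L, D = 196, 512
a = torch.randn(L, D)                         # annotation vectors a_1..a_L for one image
alpha = torch.softmax(torch.randn(L), dim=0)  # attention weights {alpha_i}, sum to 1

# s_t ~ Multinoulli_L({alpha_i}): draw one location index
idx = torch.multinomial(alpha, num_samples=1)

# one-hot s_t; the context vector z_hat is then just the selected annotation vector
s = torch.zeros(L)
s[idx] = 1.0
z_hat = (s.unsqueeze(1) * a).sum(dim=0)

assert torch.equal(z_hat, a[idx.item()])  # only the chosen region contributes
```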

4. Objective Function (Variational Lower Bound)¶

$$ L_s = \sum_s p(s | \mathbf{a}) \log p(\mathbf{y} | s, \mathbf{a}) $$

$$ L_s \leq \log p(\mathbf{y} | \mathbf{a}) $$

where...

  • The inequality $ L_s \leq \log p(\mathbf{y} \mid \mathbf{a}) $ indicates that $ L_s $ is a lower bound on the log-likelihood. Lower bound means it is always less than or equal to the true log probability of the caption given the image.
  • The objective function $ L_s $ involves summing over all possible attention sequences, but since we can't compute this exactly, we use a weighted sum where the weights are the probabilities of each attention sequence: $ p(s | \mathbf{a}) $.

Significance: $ L_s $ serves as a computationally feasible approximation to the true log-likelihood of generating the correct caption. It is about finding the best possible set of attention decisions (where to focus in the image at each step) to maximize the probability of correctly generating the caption sequence. It's optimized during training to improve captioning accuracy.

5. Gradient Approximation via Monte Carlo Sampling¶

$$ \frac{\partial L_s}{\partial W} \approx \frac{1}{N} \sum_{n=1}^N\left[\frac{\partial \log p(\mathbf{y} | \tilde{s}^n, \mathbf{a})}{\partial W} + \log p(\mathbf{y} | \tilde{s}^n, \mathbf{a}) \frac{\partial \log p(\tilde{s}^n | \mathbf{a})}{\partial W}\right] $$

where...

  • The gradient of $ L_s $ is approximated as an average over $ N $ sampled sequences of attention decisions.
  • $ \frac{\partial \log p(\mathbf{y} | \tilde{s}^n, \mathbf{a})}{\partial W} $ is the gradient of the log likelihood of the generated word sequence given the sampled attention sequence and the image features
  • $ \log p(\mathbf{y} | \tilde{s}^n, \mathbf{a}) $ is the log likelihood of the word sequence given the sampled attention sequence and the image features
  • $ \frac{\partial \log p(\tilde{s}^n | \mathbf{a})}{\partial W} $ is the gradient of the log probability of the sampled attention sequence given the image features

Significance: Provides a practical method to approximate the gradient of $ L_s $ for model optimization, as direct computation is infeasible due to the stochastic nature of hard attention.

6. Variance Reduction Techniques¶

Moving Average Baseline

$$ b_k = 0.9 \times b_{k-1} + 0.1 \times \log p(\mathbf{y} | \tilde{s}_k, \mathbf{a}) $$

where...

  • $ b_k $ represents the moving average baseline at the $ k $-th mini-batch during training.

  • The formula for $ b_k $ involves an exponential decay component, which is a method commonly used to calculate a moving average that gives more weight to recent observations. In this case, the decay is controlled by the coefficient $ 0.9 $. This coefficient multiplies the previous baseline $ b_{k-1} $, effectively reducing its influence over time.

  • Significance: Reduces the variance in the Monte Carlo estimator of the gradient, stabilizing training.
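In code, this exponential moving average is a single line; a minimal sketch, with the 0.9/0.1 coefficients taken from the formula above:

```python
def update_baseline(b_prev, log_p, decay=0.9):
    # b_k = decay * b_{k-1} + (1 - decay) * log p(y | s~_k, a)
    return decay * b_prev + (1.0 - decay) * log_p
```

Subtracting this baseline from the reward centers it around zero without biasing the gradient estimate, which is what reduces the estimator's variance.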

Entropy Regularization $$ \lambda_e \frac{\partial H[\tilde{s}^n]}{\partial W} $$

where...

  • $ H[\tilde{s}^n] $ is the entropy of the sampled attention sequence $ \tilde{s}^n $. By adding the entropy of the attention distribution to the objective function, the model is encouraged to maintain a degree of uncertainty in its attention decisions. This encouragement for higher entropy effectively promotes exploration in the model's attention mechanism. Instead of always focusing on the same regions for similar images or features, the model is nudged to explore other potentially informative regions as well.

  • $ \lambda_e $ is a hyperparameter controlling the strength of the entropy regularization.

  • Significance: Encourages exploration in attention decisions, further reducing variance and improving model robustness. A model that explores more diverse attention strategies is less likely to get stuck in local optima and can generalize better.

7. Final Learning Rule¶

$$ \frac{\partial L_s}{\partial W} \approx \frac{1}{N} \sum_{n=1}^N\left[\frac{\partial \log p(\mathbf{y} | \tilde{s}^n, \mathbf{a})}{\partial W} + \lambda_r(\log p(\mathbf{y} | \tilde{s}^n, \mathbf{a}) - b) \frac{\partial \log p(\tilde{s}^n | \mathbf{a})}{\partial W} + \lambda_e \frac{\partial H[\tilde{s}^n]}{\partial W}\right] $$

where...

  • $ \frac{\partial \log p(\mathbf{y} | \tilde{s}^n, \mathbf{a})}{\partial W} $ is the gradient of the log likelihood of the word sequence, given the sampled attention sequence, with respect to the model parameters $ W $
  • $ \lambda_r(\log p(\mathbf{y} | \tilde{s}^n, \mathbf{a}) - b) \frac{\partial \log p(\tilde{s}^n | \mathbf{a})}{\partial W} $ resembles the REINFORCE learning rule from reinforcement learning
  • $ \lambda_r $ is a hyperparameter that controls the influence of the reinforcement learning-based reward signal in the training process. It adjusts the balance between following the gradient of the attention model’s log probability and the reinforcement learning-based reward signal.
  • $ \lambda_e $ is a hyperparameter controlling the strength of the entropy regularization.


Significance: Combines all elements (gradient approximation, baseline, and entropy regularization) into a single learning rule for training the model with hard attention.

Connection to REINFORCE Learning Rule¶

This approach aligns with the REINFORCE rule from reinforcement learning, treating the sequence of attention decisions as actions with associated rewards based on the log likelihood of the generated caption.

Deterministic Soft Attention¶

Unlike stochastic hard attention, which involves random sampling (where the model discretely chooses specific regions to focus on), soft attention deterministically calculates a weighted sum of all parts of the input, allowing for straightforward optimization and learning. Thus, in soft attention, the model considers all regions (or locations, i.e. all annotation vectors) in an image at each step of generating a caption, but with varying degrees of focus.

1. Expectation of the Context Vector¶

$$ \mathbb{E}_{p(s_t \mid \mathbf{a})}[\hat{\mathbf{z}}_t] = \sum_{i=1}^L \alpha_{t, i} \mathbf{a}_i $$

where...

  • The weights $ \alpha_{t, i} $ are the attention probabilities for each region at time step $ t $

  • $ L $ is the total number of regions.

  • Explanation: This formula represents the expected context vector as a weighted sum of all annotation vectors $ \mathbf{a}_i $ from the image.

  • Significance: It provides a 'soft' focus by blending information from all parts of the image, with more emphasis on the areas deemed most relevant by the model.
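A tiny numeric sketch of this weighted sum (sizes are made up: $ L = 3 $ regions with 4-dimensional annotation vectors):

```python
import torch

L, D = 3, 4                                                  # regions, feature dims (toy)
a = torch.arange(L * D, dtype=torch.float32).reshape(L, D)   # annotation vectors a_i
alpha = torch.tensor([0.2, 0.3, 0.5])                        # attention weights, summing to 1
z_hat = (alpha.unsqueeze(1) * a).sum(dim=0)                  # E[z^_t] = sum_i alpha_i * a_i
```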

2. Deterministic Attention Model¶

  • Concept: The soft attention mechanism is deterministic because it doesn't involve random sampling. Instead, it uses a predictable, continuous function (the weighted sum) to determine the focus.
  • Significance: This deterministic nature makes the entire model, including the attention mechanism, smooth and differentiable, allowing for standard backpropagation during training.

3. Normalized Weighted Geometric Mean (NWGM)¶

$$ N W G M\left[p\left(y_t=k \mid \mathbf{a}\right)\right] = \frac{\exp \left(\mathbb{E}_{p\left(s_t \mid a\right)}\left[n_{t, k}\right]\right)}{\sum_j \exp \left(\mathbb{E}_{p\left(s_t \mid a\right)}\left[n_{t, j}\right]\right)} $$
  • Explanation: NWGM is used for calculating the probability of the next word in the caption. It approximates the softmax probability distribution over possible next words by applying softmax to the expectations of the underlying linear projections.
  • Significance: This approach aligns with the standard mechanism for generating predictions in neural networks, facilitating efficient training and prediction.
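A small sketch of the distinction NWGM makes: the softmax is applied to the *expected* pre-softmax activations $ n_{t,k} $, rather than taking the expectation of softmax outputs (toy numbers, two attention outcomes and two words):

```python
import torch

n = torch.tensor([[2.0, 0.0],      # activations n_{t,k} under attention outcome 1
                  [0.0, 2.0]])     # ... under attention outcome 2
p_s = torch.tensor([0.7, 0.3])     # p(s_t | a) over the two outcomes

expected_n = (p_s.unsqueeze(1) * n).sum(dim=0)   # E[n_{t,k}]
nwgm = torch.softmax(expected_n, dim=0)          # NWGM[p(y_t = k | a)]
```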

4. Simplification for Learning¶

  • Concept: Learning with deterministic soft attention is more straightforward than with stochastic hard attention. The expected context vector can be directly used in forward propagation, and standard backpropagation can be applied for training.
  • Significance: This simplification means that models with soft attention can be trained efficiently with conventional optimization algorithms, making them practical for large-scale applications.

5. Approximation to Marginal Likelihood¶

  • Concept: Deterministic soft attention can be seen as an approximation to optimizing the marginal likelihood over attention locations, which is a complex problem in the stochastic hard attention setting.
  • Significance: It provides a practical and computationally efficient way to capture the benefits of attention mechanisms without the need for complex sampling or optimization methods required by stochastic hard attention.

Conclusion¶

Deterministic Soft Attention offers a practical and efficient method for implementing attention mechanisms in neural networks, especially for tasks like image captioning. By calculating a weighted sum of input features and avoiding the complexity of stochastic sampling, it facilitates smooth and differentiable models that are amenable to standard training techniques. This approach enables the model to effectively focus on relevant parts of the input while maintaining computational tractability and ease of training.

Implementation of the Attention Mechanism¶

Enough theory, let's apply this knowledge to the code. The attention mechanism is implemented as a separate module used by the decoder at each time step. I use the deterministic soft attention mechanism described above. The reason for this is that implementing stochastic hard attention is dramatically more complex and would require reinforcement learning to train the model, which is beyond the scope of the deep learning course for which I am implementing this project. It is worth noting, however, that Stochastic Hard Attention performed slightly better than Deterministic Soft Attention in the original paper, as measured by the BLEU score.

import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, encoder_dim):
        super(Attention, self).__init__()
        self.U = nn.Linear(512, 512)
        self.W = nn.Linear(encoder_dim, 512)
        self.v = nn.Linear(512, 1)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(1)

    def forward(self, img_features, hidden_state):
        U_h = self.U(hidden_state).unsqueeze(1)
        W_s = self.W(img_features)
        att = self.tanh(W_s + U_h)
        e = self.v(att).squeeze(2)
        alpha = self.softmax(e)
        context = (img_features * alpha.unsqueeze(2)).sum(1)
        return context, alpha
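A quick shape check of this module with dummy tensors (the class is repeated so the snippet runs standalone; batch size 2 and 196 regions, matching a 14×14 feature map, are illustrative):

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    # Same module as above, repeated here so the snippet runs on its own.
    def __init__(self, encoder_dim):
        super().__init__()
        self.U = nn.Linear(512, 512)
        self.W = nn.Linear(encoder_dim, 512)
        self.v = nn.Linear(512, 1)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(1)

    def forward(self, img_features, hidden_state):
        U_h = self.U(hidden_state).unsqueeze(1)
        W_s = self.W(img_features)
        att = self.tanh(W_s + U_h)
        e = self.v(att).squeeze(2)
        alpha = self.softmax(e)
        context = (img_features * alpha.unsqueeze(2)).sum(1)
        return context, alpha

att = Attention(encoder_dim=512)
feats = torch.randn(2, 196, 512)   # (batch, regions, encoder_dim)
h = torch.randn(2, 512)            # LSTM hidden state
context, alpha = att(feats, h)
# context: (2, 512); alpha: (2, 196), each row summing to 1 via softmax
```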

Calculating Attention Weights and Context Vector¶

Formula: $ \mathbb{E}_{p(s_t \mid \mathbf{a})}[\hat{\mathbf{z}}_t] = \sum_{i=1}^L \alpha_{t, i} \mathbf{a}_i $

Code Implementation (in Attention class):

  • Attention Weights Calculation:

    U_h = self.U(hidden_state).unsqueeze(1)
    W_s = self.W(img_features)
    att = self.tanh(W_s + U_h)
    e = self.v(att).squeeze(2)
    alpha = self.softmax(e)
    

    Here, U_h and W_s are the transformed hidden state and image features, respectively. alpha is the attention probability for each region in the image.

  • Context Vector Calculation:

    context = (img_features * alpha.unsqueeze(2)).sum(1)
    

    This line then computes the weighted sum of the image features based on the attention weights, resulting in the context vector.

Changes to the Implementation¶

In an effort to make the implementation more closely align with the paper and add new features, I made many changes to the base implementation.

I will highlight the most important ones here, before going into more detail below:

  • BERT Embeddings: I added the capability to use BERT embeddings instead of the standard embeddings. This is a key difference from the original paper, where the authors used standard embeddings. I wanted to explore the impact of using BERT embeddings, which are pretrained on a large corpus of text, on the performance of the model. I also wanted to see if using BERT embeddings would allow for faster convergence during training, as the model would not have to learn the embeddings from scratch.
  • Advanced Deep Output: I implemented the advanced deep output layer, which is described in Section 3.1.2 of the paper. This layer transforms the LSTM state, context vector, and embedding into a single vector, which is then used to predict the next word in the caption. This is an alternative to the simple deep output layer, which only uses the LSTM state to predict the next word. I wanted to explore the impact of using the advanced deep output layer, which is more complex, on the performance of the model.

A non-comprehensive list of other changes I made to the implementation:

  • Fix accuracy calculation in utils.py to not be artificially inflated by padding tokens being counted as correct.
  • Fix BLEU score calculation to be based on the predicted caption cut off at the first <eos> token occurrence, instead of letting the model predict up to the max sequence length and basing the BLEU score on that.
  • Add the capability to use a fraction of the dataset, controlled by the fraction argument in the ImageCaptionDataset class. This is useful for quick iterations or debugging.
  • Performance optimizations at various points in the code (led to a ~7x speedup in training time)
    • Preprocessing of images and captions is done once before training, and the preprocessed data is stored in memory. This avoids redundant processing for each epoch or batch during training.
    • Pinned memory is used in the DataLoader to speed up data transfer to the GPU.
    • Vectorized the calculate_caption_lengths function in utils.py, which the PyTorch bottleneck profiler had highlighted as a significant bottleneck.

Deep Output Layer¶

Explanation¶

As described towards the end of Section 3.1.2, Show, Attend and Tell utilizes a deep output layer (Pascanu et al., 2014) to compute the output word probability given the current state of the LSTM, the context vector from the attention mechanism, and the previously generated word.

Let's break down this formula and map its components to the code:

$$ p\left(\mathbf{y}_t \mid \mathbf{a}, \mathbf{y}_1^{t-1}\right) \propto \exp \left(\mathbf{L}_o\left(\mathbf{E} \mathbf{y}_{t-1}+\mathbf{L}_h \mathbf{h}_t+\mathbf{L}_z \hat{\mathbf{z}}_t\right)\right) $$

Here, $ p\left(\mathbf{y}_t \mid \mathbf{a}, \mathbf{y}_1^{t-1}\right) $ is the probability of the output word $ \mathbf{y}_t $ at time $ t $ given the image features $ \mathbf{a} $ and the previously generated words $ \mathbf{y}_1^{t-1} $.

In this formula:

  • $ \mathbf{y}_t $ is the output word at time $ t $.
  • $ \mathbf{a} $ represents the set of annotation vectors (image features).
  • $ \mathbf{y}_1^{t-1} $ are the previously generated words up to time $ t-1 $.
  • $ \mathbf{L}_o, \mathbf{L}_h, \mathbf{L}_z $ are learned weight matrices (initialized randomly).
  • $ \mathbf{E} $ is the embedding matrix for the previous word $ \mathbf{y}_{t-1} $.
  • $ \mathbf{h}_t $ is the hidden state of the LSTM at time $ t $.
  • $ \hat{\mathbf{z}}_t $ is the context vector at time $ t $, generated by the attention mechanism.

Now let's map this to the Decoder's code in forward():

  1. Embedding of the Previous Word ($ \mathbf{E} \mathbf{y}_{t-1} $): This is done using the self.embedding layer in the code.

    embedding = self.embedding(prev_words)
    
  2. Hidden State of the LSTM ($ \mathbf{h}_t $): The h variable in the code represents the hidden state of the LSTM at each time step.

    h, c = self.lstm(lstm_input, (h, c))
    
  3. Context Vector ($ \hat{\mathbf{z}}_t $): The context vector is computed by the attention mechanism in the self.attention layer.

    context, alpha = self.attention(img_features, h)
    
  4. Combining and Transforming for Output Prediction: The output word probability is computed by combining these elements and applying the learned weight matrices. Here, this operation is currently condensed into one self.deep_output layer transforming just the hidden state $ \mathbf{h}_t $. In a more complex or literal implementation of the Deep Output Layer as laid out in Show, Attend and Tell, you would expect to see multiple such layers, each followed by a non-linear activation function.

    output = self.deep_output(self.dropout(h))
    

Implementing the Deep Output Layer¶

As we saw above, the paper describes the deep-output RNN as having multiple layers, each followed by a non-linear activation function. The implementation by AaronCCWong only had one layer transforming the hidden state of the LSTM. Therefore, I implemented the deep output as described in the paper, with multiple layers and non-linear activations. This can be enabled by setting the use_advanced_deep_output=True flag when training the model.

$$ p\left(\mathbf{y}_t \mid \mathbf{a}, \mathbf{y}_1^{t-1}\right) \propto \exp \left(\mathbf{L}_o\left(\mathbf{E} \mathbf{y}_{t-1}+\mathbf{L}_h \mathbf{h}_t+\mathbf{L}_z \hat{\mathbf{z}}_t\right)\right) $$

Where:

  • $ \mathbf{L}_o, \mathbf{L}_h, \mathbf{L}_z $ are learned weight matrices for transforming the embedding, hidden state, and context vector respectively.
  • $ \exp(\cdot) $, together with the proportionality, corresponds to a softmax normalization over the vocabulary.
class Decoder(nn.Module):
    def __init__(self, vocabulary_size, encoder_dim, tf=False, ado=False):
        # ...
        # Deep output layers

        # Advanced DO: Layers for transforming LSTM state, context vector and embedding for DO-RNN
        if self.use_advanced_deep_output:
            hidden_dim, intermediate_dim = self.embedding_size, self.embedding_size
            self.f_h = nn.Linear(hidden_dim, intermediate_dim)  # Transforms LSTM hidden state
            self.f_z = nn.Linear(encoder_dim, intermediate_dim)  # Transforms context vector
            self.f_out = nn.Linear(intermediate_dim, self.vocabulary_size)  # Transforms combined vector (sum of embedding, LSTM state, and context vector) to voc_size
            self.relu = nn.ReLU()  # Activation function
            self.dropout = nn.Dropout()

        # Simple DO: Layer for transforming LSTM state to vocabulary
        self.deep_output = nn.Linear(self.embedding_size, self.vocabulary_size)  # Maps LSTM outputs to vocabulary
        self.dropout = nn.Dropout()
        # ...

    def forward(self, img_features, captions):
        # ...
        for t in range(max_timespan):
            # ...
            # Generate word prediction
            if self.use_advanced_deep_output:
                if self.use_tf:
                    output = self.advanced_deep_output(self.dropout(h), context, caption_embedding[:, t])
                else:
                    output = self.advanced_deep_output(self.dropout(h), context, previous_predicted_token_embedding)
            else:
                output = self.deep_output(self.dropout(h))
            # ...

    def advanced_deep_output(self, h, context, current_embedding):
        # Combine the LSTM state and context vector
        h_transformed = self.relu(self.f_h(h))
        z_transformed = self.relu(self.f_z(context))

        # Sum the transformed vectors with the embedding
        combined = h_transformed + z_transformed + current_embedding

        # Transform the combined vector & compute the output word probability
        return self.relu(self.f_out(combined))

Integration of BERT Embeddings in the Decoder¶

In my implementation of the "Show, Attend and Tell" model, I explored the use of both standard embeddings and BERT embeddings in the decoder module. BERT, being a transformer-based model, provides contextually rich embeddings compared to standard word embeddings. Integrating BERT embeddings required modifications to the decoder architecture and data processing pipeline (see generate_json_data_bert.py and decoder.py for details).

BERT Data Preprocessing¶

  • Tokenization Differences: Unlike standard tokenization, which typically involves splitting text into words or subwords, BERT's tokenizer can further break down words into smaller units (word pieces). This feature of BERT tokenization is crucial because it helps in handling out-of-vocabulary words more effectively but also results in a higher number of tokens for the same text compared to the standard tokenization approach I used.
  • Increased Maximum Caption Length: Due to BERT's tokenization approach, I increased the max_caption_length to 30. This adjustment ensures that the increased token count, a consequence of BERT's finer granularity in tokenization, doesn't lead to excessive truncation of the captions.
  • BERT's Special Tokens: BERT utilizes specific special tokens, namely [CLS] (used at the beginning of a text to signify classification tasks) and [SEP] (used as a separator, e.g., between sentences). In my preprocessing routine, I aligned the tokenizer's bos_token (beginning of sequence) and eos_token (end of sequence) with BERT's [CLS] and [SEP] tokens respectively.
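To make the word-piece effect concrete, here is a simplified greedy longest-match-first tokenizer over a tiny hand-made vocabulary (the real tokenizer is BertTokenizer from the transformers library; the vocabulary and example words here are invented purely for illustration):

```python
# Simplified WordPiece-style tokenization: greedy longest-match against a
# tiny, invented vocabulary. Continuation pieces carry the "##" prefix.
VOCAB = {"a", "dog", "snow", "snowboard", "##board", "##er", "play", "##ing"}

def wordpiece(word, vocab=VOCAB):
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        while end > start:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece      # mark continuation pieces
            if piece in vocab:
                pieces.append(piece)
                start = end
                break
            end -= 1
        else:
            return ["[UNK]"]              # no piece matched at this position
    return pieces
```

A single caption word like "snowboarder" becomes two tokens ("snowboard", "##er"), which is exactly why the maximum caption length had to grow.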

BERT Embedding Integration¶

  • BERT Model and Tokenizer: When bert=True, I utilized the BertModel and BertTokenizer from the transformers library. The model used is bert-base-uncased, which provides a good balance between performance and computational efficiency. Its uncased nature also makes it suitable for the flickr8k dataset, which contains only lower-case words in the captions.
  • Embedding Size and Vocabulary: I set the embedding_size to BERT's hidden size (768) and vocabulary_size to BERT's vocabulary size. This ensures compatibility with the pre-trained BERT model.
  • Embedding Layer Replacement: Instead of a standard embedding layer, BERT's own input embeddings are used. This ensures that the input tokens are represented using the rich, pre-trained context-aware embeddings from BERT.
  • Freezing BERT Embeddings: To maintain the integrity of the pre-trained embeddings and reduce training complexity, I froze the parameters of the BERT embedding layer.
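The freezing step itself is only a few lines; a sketch using a plain nn.Embedding stand-in so it runs without downloading a checkpoint (with the real model, one would take BertModel's input-embedding layer instead):

```python
import torch.nn as nn

# Stand-in for BERT's input embeddings (bert-base: 30522 tokens, hidden size 768).
embedding_layer = nn.Embedding(30522, 768)

# Freeze: excluded from gradient updates, keeping the pre-trained weights intact.
for param in embedding_layer.parameters():
    param.requires_grad = False
```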

Decoder Architecture Adjustments¶

  • Initialization of LSTM States: The LSTM cell states (h and c) are initialized using linear transformations of the encoder's output, followed by a tanh activation. This remains unchanged for BERT integration.
  • Attention Mechanism: The attention mechanism employed remains the same regardless of the embedding type. It computes context vectors based on the hidden state of the LSTM and the encoder features.
  • Teacher Forcing and Input Preparation: When using BERT embeddings, the initial token for decoding sequences is set to BERT's cls_token_id. This differs from the standard approach where a <start> token is used.
  • LSTM Cell Processing: The LSTM cell takes as input a concatenation of the current word embedding and the context vector. For BERT embeddings, the word embedding is directly obtained from the BERT embedding layer.
  • Advanced Deep Output (DO) Option: If use_advanced_deep_output is true, the model employs an enhanced deep output layer, which combines the transformed LSTM state, context vector, and current embedding to predict the next word. This mechanism is independent of the embedding type but requires careful dimensionality alignment (Standard embeddings are size 512, BERT 768).

Considerations for BERT Integration¶

  • Memory and Computational Efficiency: BERT models are resource-intensive. To mitigate this, I deleted the main BERT model after extracting its embedding layer and ensured that the embedding layer's parameters are frozen. This reduces the checkpoint size significantly and removes layers that are not used from the Decoder's state dictionary.
  • Compatibility with Existing Architecture: The decoder is designed to be flexible, allowing for both standard and BERT embeddings. This required consistent handling of embedding dimensions and vocabulary sizes.

Metrics¶

A quick introduction to the metrics used to evaluate the performance of the model.

BLEU Score¶

Overview of BLEU Score¶

BLEU (Bilingual Evaluation Understudy) Score is a widely used metric for evaluating the quality of text which has been machine-translated from one language to another. In the context of image captioning, it's adapted to assess the quality of generated captions compared to a set of reference captions.

Formula¶

The BLEU score is calculated based on n-gram precision. For each n-gram size (up to a predefined limit, typically 4), it compares the n-grams of the generated text with the n-grams of the reference texts, counting the number of matches. These matches are then adjusted by a brevity penalty to penalize overly short predictions. The formula for BLEU score is:

$$ \text{BLEU} = \text{BP} \cdot \exp\left(\sum_{n=1}^{N} w_n \log p_n\right) $$

Where:

  • $p_n$ is the precision of n-grams.
  • $w_n$ are weights for each n-gram size (usually uniform).
  • $\text{BP}$ is the brevity penalty, calculated as:
    • $\text{BP} = 1$ if candidate length > reference length.
    • $\text{BP} = e^{(1 - \text{reference length} / \text{candidate length})}$ if candidate length ≤ reference length.
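A minimal, unsmoothed BLEU sketch following the formula above (illustrative only; metric code in practice should rely on a library implementation with smoothing):

```python
import math
from collections import Counter

def modified_precision(candidate, references, n):
    # Clipped n-gram precision: candidate counts capped at the max reference count.
    cand = Counter(tuple(candidate[i:i + n]) for i in range(len(candidate) - n + 1))
    if not cand:
        return 0.0
    max_ref = Counter()
    for ref in references:
        for ng, c in Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1)).items():
            max_ref[ng] = max(max_ref[ng], c)
    clipped = sum(min(c, max_ref[ng]) for ng, c in cand.items())
    return clipped / sum(cand.values())

def bleu(candidate, references, max_n=4):
    c = len(candidate)
    # closest reference length (ties broken toward the shorter reference)
    r = min((len(ref) for ref in references), key=lambda length: (abs(length - c), length))
    bp = 1.0 if c > r else math.exp(1.0 - r / c)   # brevity penalty
    precisions = [modified_precision(candidate, references, n) for n in range(1, max_n + 1)]
    if min(precisions) == 0.0:
        return 0.0                                  # unsmoothed: any zero precision kills the score
    return bp * math.exp(sum(math.log(p) for p in precisions) / max_n)
```

With a perfect match the score is 1; a correct but truncated caption keeps perfect n-gram precision yet is discounted by the brevity penalty.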

Application¶

BLEU score provides a quantitative measure of the similarity between the machine-generated text and human-generated reference texts. Higher BLEU scores indicate better alignment with the reference captions, suggesting higher quality translations or captions. Scores between 0.6 and 0.7 are considered the best one can achieve (https://towardsdatascience.com).

Top-N Accuracy¶

Understanding Top-N Accuracy¶

Top-N accuracy is a performance metric used to evaluate classification models, including those in image captioning where the task can be viewed as predicting the next word in a sequence.

Formula¶

Top-N accuracy is computed as follows:

$$ \text{Top-N Accuracy} = \frac{\text{Number of times the correct label is in the top N predictions}}{\text{Total number of predictions}} $$

Application¶

  • In the context of image captioning, "Top-1 accuracy" measures how often the model's most likely predicted word (highest probability) is the correct next word. "Top-5 accuracy" expands this to consider the top 5 predicted words.
  • Higher Top-N accuracy indicates a better performing model, as it frequently predicts the correct next word within its top N choices.
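A minimal version of this metric with torch.topk (illustrative; the repository's sequence_accuracy additionally ignores padding positions):

```python
import torch

def top_n_accuracy(logits, targets, n):
    # logits: (positions, vocab_size); targets: (positions,) of correct word ids.
    top_n = logits.topk(n, dim=-1).indices             # ids of the n highest scores
    hits = (top_n == targets.unsqueeze(-1)).any(dim=-1)
    return hits.float().mean().item()

# Toy example: position 0 is a top-1 hit, position 1 only a top-2 hit.
logits = torch.tensor([[0.1, 0.9, 0.0],
                       [0.8, 0.05, 0.1]])
targets = torch.tensor([1, 2])
```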

Experiment Setup¶

Overview¶

For my deep learning experiment, I focused on training four distinct variants of the "Show, Attend and Tell" model to investigate the impact of different configurations, particularly the presence of attention mechanisms and the use of BERT embeddings. The experiments were conducted using a controlled setup, with consistent hyperparameters and data splits.

Tracking and Logging with Weights & Biases (W&B)¶

  • W&B Integration: Throughout the training and evaluation phases, I utilized Weights & Biases (W&B). This allowed for efficient monitoring of the model's performance and hyperparameters, and easy retrieval of different model checkpoints for evaluation.
  • Metrics Tracked: The key metrics monitored included cross-entropy loss, BLEU Scores (n-grams 1 to 4), and accuracy (top-1 and top-5).

Dataset and Splits¶

  • Data Loader: The dataset used was defined by Karpathy's splits. This provided a standard and reliable basis for comparing model performance with the original paper.
  • Data Path: The dataset path was set to the Flickr8k data saved at 'data/flickr8k'.

Model Variants¶

  1. Plain with Attention (plain_att): This variant included attention, teacher forcing, and advanced deep output but did not utilize BERT embeddings.
  2. Plain without Attention (plain_noatt): Similar to the first variant but without the attention mechanism.
  3. BERT with Attention (bert_att): This variant employed both BERT embeddings and the attention mechanism, alongside teacher forcing and advanced deep output.
  4. BERT without Attention (bert_noatt): This setup utilized BERT embeddings but without the attention mechanism.

Controllable Hyperparameters¶

  • Batch Size, Epochs, and Learning Rate: The default batch size was set to 64, with 8 epochs of training and a learning rate of 1e-4 for the decoder.
  • Step Size and Alpha-C: The step size for learning rate annealing was set to 5, with an alpha-C regularization constant defaulting to 1.
  • Random Seed and Log Interval: A random seed of 42 was used for reproducibility, and the logging interval was set to every 50 batches.
  • Data Fraction and Network Choice: The full dataset was utilized (fraction=1.0), and the default network for the encoder was 'vgg19'.
  • Additional Flags: Flags for teacher forcing (--tf), advanced deep output (--ado), BERT embeddings (--bert), and attention mechanism (--attention) were used to toggle these features.

Experiment Focus¶

The primary objective of these experiments was to assess how different configurations (attention mechanism and BERT embeddings) influence the model's performance. By training these four variants under consistent conditions, I aimed to draw meaningful comparisons and insights into the individual and combined effects of attention and BERT embeddings on the image captioning task.

Training¶

Below, I delve into the key aspects of the training, including the distinctions between train, validation, and test modes, and other critical elements of the training pipeline.

Train, Validation, and Test Modes¶

  • Train Mode: In this mode, the encoder is set to eval mode (to use the pre-trained model without modification), and the decoder is set to train mode. This setup enables the model to learn from the training dataset. The key operations in this mode include processing images and captions, computing the forward pass, calculating loss (including attention regularization), performing backpropagation, and updating model parameters.
  • Validation Mode: Here, both the encoder and decoder are set to eval mode, ensuring no updates to the model parameters occur. The validation mode is crucial for assessing the model's performance on unseen data without the influence of training dynamics.
  • Test Mode: Similar to validation, both the encoder and decoder are in eval mode. The test mode is used for final evaluation of the model's performance, on the test dataset that the model has not seen during training or validation.

Key Components of the Training Procedure¶

  • Initialization and Configuration: The script initializes with a set seed for reproducibility, and W&B is configured for logging.
  • Tokenizer and Vocabulary: For bert=True, the BERT tokenizer and model are loaded, and the vocabulary size is set accordingly. Otherwise, the custom word dictionary is used.
  • Model Setup: The encoder and decoder are initialized based on the specified configurations, including the use of attention, teacher forcing, advanced deep output, and BERT embeddings.
  • Optimizer and Scheduler: The Adam optimizer and a learning rate scheduler are set up to optimize the decoder's parameters.
  • Data Loaders: Separate data loaders are prepared for training, validation, and test datasets.
  • Training Loop: For each epoch, the model is trained on the training dataset, validated on the validation dataset, and the learning rate is adjusted. Key metrics are logged at regular intervals.

Loss Calculation and Regularization¶

  • The loss function used is cross-entropy loss, and attention regularization is applied to encourage the model's attention to sum to 1 across the image features. In the original Show, Attend and Tell, this proved to be important for improving overall BLEU score, and qualitatively led to more rich and descriptive captions (see Section 4.2.1 of the paper).
  • In the training mode, loss is backpropagated to update model weights.
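This doubly stochastic penalty mirrors the att_regularization term in the training loop below; as a standalone function (alphas shaped (batch, time steps, regions)):

```python
import torch

def attention_regularization(alphas, alpha_c=1.0):
    # Penalize each region whose attention weights, summed over all time
    # steps, deviate from 1 (the "doubly stochastic" regularizer).
    return alpha_c * ((1.0 - alphas.sum(dim=1)) ** 2).mean()
```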

Accuracy and BLEU Score Calculation¶

  • Accuracy (Top-1 and Top-5) and BLEU scores are calculated in both validation and test modes to evaluate the model's performance.
  • Decoded captions are generated for comparison with reference captions to compute these metrics.

Logging and Visualization¶

  • W&B logs include loss, accuracy, BLEU scores, and prediction tables for each epoch.
  • The script supports logging attention visualizations, offering insights into the model's focus on different parts of the image while generating captions.

Code¶

Below are the most relevant parts of the training script (some code removed for brevity). The full training script is available at train.py.

class EvalMode(Enum):
    VALIDATION = 'val'
    TEST = 'test'

def main(args):
    set_seed(args.seed)
    wandb.init(project='show-attend-and-tell', entity='yvokeller', config=args)

    # ...

    encoder = Encoder(args.network)
    decoder = Decoder(vocabulary_size, encoder.dim, tf=args.tf, ado=args.ado, bert=args.bert, attention=args.attention)

    optimizer = optim.Adam(decoder.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.step_size)
    cross_entropy_loss = nn.CrossEntropyLoss().to(mps_device)

    train_loader = DataLoader(
        ImageCaptionDataset(data_transforms, args.data, fraction=args.fraction, bert=args.bert, split_type='train'),
        batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)

    val_loader = DataLoader(
        ImageCaptionDataset(data_transforms, args.data, fraction=args.fraction, bert=args.bert, split_type='val'),
        batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)

    test_loader = DataLoader(
        ImageCaptionDataset(data_transforms, args.data, fraction=args.fraction, bert=args.bert, split_type='test'),
        batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)

    print(f'Starting training with {args}')
    for epoch in range(1, args.epochs + 1):
        train(epoch, encoder, decoder, optimizer, cross_entropy_loss,
              train_loader, word_dict, args.alpha_c, args.log_interval, bert=args.bert, tokenizer=bert_tokenizer, args=args)
        validate(epoch, encoder, decoder, cross_entropy_loss, val_loader,
                 word_dict, args.alpha_c, args.log_interval, bert=args.bert, tokenizer=bert_tokenizer)
        scheduler.step()

    if args.perform_test == True:
        test(epoch, encoder, decoder, cross_entropy_loss, test_loader,
             word_dict, args.alpha_c, args.log_interval, bert=args.bert, tokenizer=bert_tokenizer)

    wandb.finish()


def train(epoch, encoder, decoder, optimizer, cross_entropy_loss, data_loader, word_dict, alpha_c, log_interval, bert=False, tokenizer=None, args={}):
    print(f"Epoch {epoch} - Starting train")

    encoder.eval()
    decoder.train()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for batch_idx, (imgs, captions, _) in enumerate(data_loader):
        imgs, captions = Variable(imgs).to(mps_device), Variable(captions).to(mps_device)

        img_features = encoder(imgs)
        optimizer.zero_grad()
        preds, alphas = decoder(img_features, captions)
        targets = captions[:, 1:] # skip <start> token for loss calculation

        # Calculate accuracy
        padding_idx = word_dict['<pad>'] if bert == False else tokenizer.pad_token_id
        acc1 = sequence_accuracy(preds, targets, 1, ignore_index=padding_idx, tokenizer=tokenizer)
        acc5 = sequence_accuracy(preds, targets, 5, ignore_index=padding_idx, tokenizer=tokenizer)

        # Calculate loss
        packed_targets = pack_padded_sequence(targets, [len(tar) - 1 for tar in targets], batch_first=True)[0]
        packed_preds = pack_padded_sequence(preds, [len(pred) - 1 for pred in preds], batch_first=True)[0]

        # encourage total attention (alphas) to be close to 1, thus penalize when sum is far from 1
        att_regularization = alpha_c * ((1 - alphas.sum(1)) ** 2).mean()

        loss = cross_entropy_loss(packed_preds, packed_targets)
        loss += att_regularization # pytorch autograd will calculate gradients for both loss and att_regularization
        loss.backward()
        optimizer.step()

        if bert:
            total_caption_length = calculate_caption_lengths(...)
        else:
            total_caption_length = calculate_caption_lengths(...)

        losses.update(loss.item(), total_caption_length)
        top1.update(acc1, total_caption_length)
        top5.update(acc5, total_caption_length)

        if batch_idx % log_interval == 0:
            print(f'Train Batch: [{batch_idx}/{len(data_loader)}]\t'
                  f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                  f'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                  f'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})')

        wandb.log({
            'train_loss': losses.avg, 'train_top1_acc': top1.avg, 'train_top5_acc': top5.avg, 'epoch': epoch,
            'train_loss_raw': losses.val, 'train_top1_acc_raw': top1.val, 'train_top5_acc_raw': top5.val
        })
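The `AverageMeter` helper used throughout `train` and `run_evaluation` is not shown in this notebook. A minimal sketch of what it presumably looks like (a count-weighted running average, in the style of the common PyTorch example code — the repo's actual implementation may differ in detail):

```python
class AverageMeter:
    """Tracks the latest value and a count-weighted running average."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total weight (e.g. number of caption tokens)
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
```

A call like `losses.update(loss.item(), total_caption_length)` then weights each batch's loss by the number of caption tokens it contains, so long captions contribute more to the average.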

def validate(epoch, *args, **kwargs):
    print(f"Epoch {epoch} - Starting validation")
    return run_evaluation(epoch, *args, mode=EvalMode.VALIDATION, **kwargs)

def test(epoch, *args, **kwargs):
    print(f"Epoch {epoch} - Starting test")
    return run_evaluation(epoch, *args, mode=EvalMode.TEST, **kwargs)

def run_evaluation(epoch, encoder, decoder, cross_entropy_loss, data_loader, word_dict, alpha_c, log_interval, bert=False, tokenizer=None, mode=EvalMode.VALIDATION):
    encoder.eval()
    decoder.eval()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    decoded_captions = [] # list of single assigned caption for each image
    decoded_all_captions = [] # list of list of all captions present in dataset for each image, thus captions may repeat in different lists
    decoded_hypotheses = [] # list of single predicted caption for each image

    predictions_table = wandb.Table(columns=["epoch", "mode", "target_caption", "pred_caption"])

    with torch.no_grad():
        logged_attention_visualizations_count = 0
        for batch_idx, (imgs, captions, all_captions) in enumerate(data_loader):
            imgs, captions = Variable(imgs).to(mps_device), Variable(captions).to(mps_device)
            img_features = encoder(imgs)
            preds, alphas = decoder(img_features, captions)
            targets = captions[:, 1:]

            # Calculate accuracy
            padding_idx = word_dict['<pad>'] if bert == False else tokenizer.pad_token_id
            acc1 = sequence_accuracy(preds, targets, 1, ignore_index=padding_idx, tokenizer=tokenizer)
            acc5 = sequence_accuracy(preds, targets, 5, ignore_index=padding_idx, tokenizer=tokenizer)

            # Calculate loss
            packed_targets = pack_padded_sequence(targets, [len(tar) - 1 for tar in targets], batch_first=True)[0]
            packed_preds = pack_padded_sequence(preds, [len(pred) - 1 for pred in preds], batch_first=True)[0]

            att_regularization = alpha_c * ((1 - alphas.sum(1)) ** 2).mean()

            loss = cross_entropy_loss(packed_preds, packed_targets)
            loss += att_regularization

            if bert:
                total_caption_length = calculate_caption_lengths(...)
            else:
                total_caption_length = calculate_caption_lengths(...)

            losses.update(loss.item(), total_caption_length)
            top1.update(acc1, total_caption_length)
            top5.update(acc5, total_caption_length)

            if bert:
                pass  # ... DECODE CAPTIONS BERT
            else:
                pass  # ... DECODE CAPTIONS STANDARD

            if batch_idx % log_interval == 0:
                print(f'{mode} Batch: [{batch_idx}/{len(data_loader)}]\t'
                    f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                    f'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                    f'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})')

            if mode == EvalMode.TEST:
                # Calculate the start index for the current batch
                batch_start_idx = batch_idx * len(imgs)

                # Log the attention visualizations
                for img_idx, img_tensor in enumerate(imgs):
                    # Skip attention visualization if already logged enough
                    if logged_attention_visualizations_count >= 50:
                        break
                    logged_attention_visualizations_count += 1

                    # Calculate the global index for decoded_hypotheses and decoded_captions lists
                    global_caption_idx = batch_start_idx + img_idx

                    if len(decoded_hypotheses[global_caption_idx]) == 0:
                        print(f'No caption for image {global_caption_idx}, skipping attention visualization')
                        continue

                    log_attention_visualization_plot(img_tensor, alphas, decoded_hypotheses, decoded_captions, batch_idx, img_idx, global_caption_idx, encoder)

        bleu_1 = corpus_bleu(decoded_all_captions, decoded_hypotheses, weights=(1, 0, 0, 0))
        bleu_2 = corpus_bleu(decoded_all_captions, decoded_hypotheses, weights=(0.5, 0.5, 0, 0))
        bleu_3 = corpus_bleu(decoded_all_captions, decoded_hypotheses, weights=(0.33, 0.33, 0.33, 0))
        bleu_4 = corpus_bleu(decoded_all_captions, decoded_hypotheses)

        wandb.log({
            'epoch': epoch,
            f'{epoch}_{mode.value}_caption_predictions': predictions_table,
            f'{mode.value}_loss': losses.avg, f'{mode.value}_top1_acc': top1.avg, f'{mode.value}_top5_acc': top5.avg,
            f'{mode.value}_loss_raw': losses.val, f'{mode.value}_top1_acc_raw': top1.val, f'{mode.value}_top5_acc_raw': top5.val,
            f'{mode.value}_bleu1': bleu_1, f'{mode.value}_bleu2': bleu_2, f'{mode.value}_bleu3': bleu_3, f'{mode.value}_bleu4': bleu_4,
        })

        print(f'{mode} Epoch: {epoch}\t'
              f'BLEU-1 ({bleu_1})\t'
              f'BLEU-2 ({bleu_2})\t'
              f'BLEU-3 ({bleu_3})\t'
              f'BLEU-4 ({bleu_4})\t')
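The `att_regularization` term in `train` and `run_evaluation` above implements the paper's doubly stochastic attention regularization: it encourages each image region to receive a total attention of roughly 1 summed over the whole caption. A small NumPy illustration (the shape of `alphas` is my assumption about the decoder's output):

```python
import numpy as np

# assumed shape: alphas is (batch, timesteps, num_pixels)
alpha_c = 1.0
alphas = np.full((2, 4, 5), 0.2)        # uniform attention over 5 pixels per step

per_pixel_total = alphas.sum(axis=1)    # (2, 5): attention each pixel got over time
penalty = alpha_c * ((1 - per_pixel_total) ** 2).mean()
# each pixel receives 4 * 0.2 = 0.8 total attention -> penalty ≈ (1 - 0.8)^2 = 0.04
```

Attention maps that neglect some regions entirely (or dwell on one region) drive the per-pixel totals away from 1 and thus increase the penalty.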

Results¶

In this section, I explore the results for the following models, trained as described in the Experiment Setup chapter:

  • Plain with Attention (plain-att-173): This variant included attention, teacher forcing, and advanced deep output but did not utilize BERT embeddings.
  • Plain without Attention (plain-noatt-175): Similar to the first variant but without the attention mechanism.
  • BERT with Attention (bert-att-176): This variant employed both BERT embeddings and the attention mechanism, alongside teacher forcing and advanced deep output.
  • BERT without Attention (bert-noatt-177): Similar to the previous variant but without the attention mechanism.

The detailed, interactive plots are available here: https://wandb.ai/yvokeller/show-attend-and-tell/reports/Image-Captioning-with-Attention--Vmlldzo2NDQ4Nzc5

Plot 1: Average Training vs. Validation Loss¶


  • Convergence: All model variants show a rapid decrease in loss at the beginning, which levels off as training progresses. This indicates that the model quickly learns the task initially but then experiences diminishing returns on learning as it converges.
  • BERT vs. Plain: Models with BERT embeddings (bert-) follow a loss trajectory similar to those without (plain-). This suggests that introducing BERT embeddings did not drastically change the loss landscape, though the BERT variants end up with a slightly higher loss (~0.25 points).
  • Attention vs. No Attention: Models with attention mechanisms (-att-) also don't show a significantly different pattern in loss compared to those without (-noatt-). This suggests that the attention mechanism also does not drastically change the loss landscape for the model. This is a bit surprising, as the attention mechanism is a key component of the model and was shown to improve performance in the original paper. I take this as a signal that something could be wrong with my implementation of the attention mechanism.
  • Overfitting Check: Training and validation loss closely align for all model configurations, suggesting that overfitting is not occurring, as there is no divergence of validation loss from training loss.
  • Variability: There is significant fluctuation in the raw loss values, which is expected due to the variance in different mini-batches.

Plot 2: Accuracy¶

Top-1 Accuracy:

This plot shows the top-1 accuracy for both training and validation phases across the different configurations. The metric represents the percentage of times the model's highest-probability prediction for the next word in the caption was indeed the correct word.

  • Analysis: All model variants start with a sharp increase in top-1 accuracy, which stabilizes as training progresses. The models reach and maintain a plateau early in training, indicating a quick adaptation to the most likely predictions.
  • Comparisons:
    • BERT vs. Plain: BERT variants perform on average ~10% worse, suggesting that the inclusion of BERT embeddings in this setup does not help the model predict the most likely next word. This is surprising, and further hyperparameter tuning would be needed to investigate it.
    • Attention vs. No Attention: Models with attention mechanisms do not show a marked improvement over those without in terms of top-1 accuracy. This could imply that the attention mechanism's benefit lies in aspects other than choosing the most likely next word, that my attention implementation is flawed, or that the benefit does not exist at all, which would contradict Show, Attend and Tell's findings.

Top-5 Accuracy:

The top-5 accuracy plot reflects how often the correct next word appears within the top five predictions of the model.

  • Analysis: Similar to top-1 accuracy, there is a rapid improvement in top-5 accuracy at the beginning of training for all configurations, followed by a plateau. The top-5 accuracy is significantly higher than top-1, as expected, since there are more chances for the correct word to be in the top five predictions.
  • Comparisons:
    • BERT vs. Plain: As with top-1 accuracy, the BERT models perform worse in top-5 accuracy, suggesting that the rich contextual embeddings provided by BERT do not improve the model's ability to rank the correct next word within the top five predictions at all.
    • Attention vs. No Attention: The presence of an attention mechanism does not result in a significant difference in top-5 accuracy.
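The top-1 and top-5 accuracies above are computed by `sequence_accuracy`, which is not shown in this notebook. A simplified NumPy sketch of the underlying top-k metric (the repo's helper presumably also handles the time dimension of padded caption batches):

```python
import numpy as np

def topk_accuracy(logits, targets, k, ignore_index=None):
    """Percentage of positions whose target is among the k highest-scoring
    predictions. logits: (N, vocab_size), targets: (N,)."""
    topk = np.argsort(logits, axis=1)[:, -k:]          # indices of the k largest logits
    hits = (topk == targets[:, None]).any(axis=1)
    if ignore_index is not None:
        hits = hits[targets != ignore_index]           # drop padding positions
    return 100.0 * hits.mean()

logits = np.array([[0.1, 0.5, 0.4],    # target 1 is the top prediction
                   [0.8, 0.1, 0.1]])   # target 2 is only in the top 2
targets = np.array([1, 2])
print(topk_accuracy(logits, targets, k=1))  # 50.0
print(topk_accuracy(logits, targets, k=2))  # 100.0
```

This also makes explicit why top-5 accuracy is always at least as high as top-1: every top-1 hit is also a top-5 hit.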

Plot 3: BLEU Scores¶

The plots show the BLEU-1 and BLEU-4 scores calculated during validation and test phases for the model variants. BLEU-1 is indicative of the unigram match between the predicted and reference captions, which is a measure of adequacy, while BLEU-4 considers longer n-gram matches up to four words, which indicates fluency.
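The core of BLEU-1, modified unigram precision, can be illustrated in a few lines of plain Python (a simplified sketch without the brevity penalty that `corpus_bleu` additionally applies):

```python
from collections import Counter

def unigram_precision(references, hypothesis):
    """Modified unigram precision (BLEU-1 core, brevity penalty omitted).
    Each hypothesis token counts as correct at most as often as it occurs
    in the best-matching reference."""
    max_ref = Counter()
    for ref in references:
        for tok, n in Counter(ref).items():
            max_ref[tok] = max(max_ref[tok], n)
    clipped = sum(min(n, max_ref[tok]) for tok, n in Counter(hypothesis).items())
    return clipped / len(hypothesis)

refs = [['a', 'dog', 'runs', 'through', 'the', 'grass']]
hyp = ['a', 'dog', 'runs', 'in', 'the', 'grass']
print(unigram_precision(refs, hyp))  # 5 of 6 unigrams match -> 0.8333...
```

BLEU-4 applies the same clipped counting to bigrams through 4-grams and combines them geometrically, which is why it rewards fluent multi-word phrases rather than isolated correct words.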


Validation BLEU Scores:

  • Analysis: Across the training steps, we can observe an increase in BLEU scores, indicating improvement in the model's captioning performance as training progresses. The scores tend to plateau, suggesting that the models have reached their performance capacity on the validation set.

  • Model Comparisons:

    • BERT vs. Plain: The BERT-based models (bert-) show worse performance than the plain models in terms of BLEU-1 and BLEU-4 scores. The increased vocabulary size with BERT embeddings (from ~10,000 to ~30,000) could be a contributing factor, as the complexity of the task grows with the larger vocabulary. Another potential reason is that words from BERT's larger vocabulary are predicted that are not present in the reference captions, leading to lower BLEU scores.
    • Attention vs. No Attention: For BLEU-1 and BLEU-4 scores, the models with attention (att-) demonstrate a modest improvement, indicating that attention potentially contributes to the adequacy and fluency of the generated captions.

Test BLEU Scores:

  • Model Performance: On the test set, the plain-att-173 model achieves the highest BLEU-1 and BLEU-4 scores.
  • Consistency: Generally, the consistency between validation and test scores for each model suggests that the models generalize well from the validation to the test set.

Comparisons with Original Paper:

The original paper reported BLEU-1 to BLEU-4 scores of 67, 44.8, 29.9, and 19.5, respectively, on the Flickr8k dataset. The best-performing of the 4 model variants (plain-att-173) achieves BLEU-1 to BLEU-4 scores of 65, 40, 23.4, and 13.3, respectively. Compared to the paper, performance thus degrades more as the n-gram order increases. This still looks quite promising, but the qualitative evaluation in the next section will provide a more comprehensive picture of the model's performance.

Overall Takeaways from Chart Analysis¶

  • BERT Integration: The integration of BERT embeddings did not yield the expected improvements in BLEU scores. The larger vocabulary size and possibly more complex output space may have contributed to this underperformance.
  • Attention Mechanism: Models with attention mechanisms perform slightly better, suggesting that the ability to focus on relevant parts of an image when generating captions is beneficial, but the differences are not significant. I would expect the attention mechanism to have a more substantial impact on the model's performance, so this needs further investigation.
  • Performance Plateau: There is a clear plateau in BLEU scores for all model variants, which may indicate that further training with the current setup might not yield significant improvements.
  • Model Generalization: The models demonstrate good generalization from validation to test sets, as indicated by the consistency in BLEU scores across these sets.

Proposed Next Steps¶

  • Vocabulary Adjustment: Experiment with reducing the vocabulary size used with BERT embeddings to see if this improves BLEU scores.
  • Model Architecture: Investigate if there are any mistakes in the implementation of the attention mechanism, which could explain the lack of significant performance improvements.
  • Hyperparameter Tuning: Perform hyperparameter tuning with a W&B sweep over parameters such as the learning rate, batch size, and the alpha_c weight for attention regularization.
  • Training Duration: Extend the number of training epochs or change the learning rate schedule to see if the models can overcome the performance plateau.
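As a concrete starting point for such a sweep, a W&B configuration could look like the following (the parameter names, ranges, and the `train_run` entry point are my assumptions, not taken from the repo):

```python
# Hypothetical W&B sweep configuration sketch
sweep_config = {
    "method": "bayes",                                   # Bayesian hyperparameter search
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "lr": {"values": [1e-4, 5e-4, 1e-3]},
        "batch_size": {"values": [32, 64, 128]},
        "alpha_c": {"min": 0.5, "max": 2.0},             # attention regularization weight
    },
}

# import wandb
# sweep_id = wandb.sweep(sweep_config, project="show-attend-and-tell")
# wandb.agent(sweep_id, function=train_run, count=20)    # train_run: hypothetical wrapper
```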

Bonus Models¶

Based on the observations from the training and evaluation, I trained two additional models to investigate the impact of different hyperparameters.

plain-lr-0.001-180:

A standard embedding model with attention, teacher forcing, and advanced deep output, trained for 8 epochs with an increased learning rate of 0.001.

  • The increased learning rate resulted in an even faster initial decrease in loss, and the model outperforms the others on most metrics.
  • Its loss is the lowest of all models, though not by a large margin, ending at 2.123.
  • BLEU-1 ends at 64.6, BLEU-4 at 13.6. BLEU-1 is thus 0.4 lower than the best previous model's, but BLEU-4 is 0.3 higher.

plain-bs-exp-178:

A standard embedding model with attention, teacher forcing, and advanced deep output, but trained for 35 epochs with a batch size of 128.

  • This model variant performed better than the models with BERT embeddings, but worse than the two models with standard embeddings, suggesting that the increased batch size and training duration did not yield improvements.

Evaluation¶

This last section focuses on the qualitative evaluation of the model's performance, including the generation of captions and attention visualizations.

Script for Generating Captions¶

I developed a script, generate_caption.py, which is designed to generate and visualize image captions. This script is an essential part of my project as it not only generates captions but also provides a visual representation of the attention mechanism at work.

Key Features¶

  • Beam Search Implementation: The script uses a beam search algorithm with a configurable number of beams. Beam search is a heuristic search algorithm that explores a graph by expanding only the most promising nodes in a limited set. In the context of caption generation, this means the script can produce more accurate and coherent captions by considering multiple top candidates at each step of the sequence generation.
  • Support for BERT and Standard Embeddings: The script supports both BERT embeddings and standard (custom) word embeddings.
  • Visualization of Attention (Alpha): The script includes functionality to visualize how the model's attention is distributed across different parts of the image for each word in the generated caption. This is crucial for understanding and interpreting the model's decision-making process.
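To make the beam search idea concrete, here is a toy sketch over log-probabilities (not generate_caption.py's actual implementation, which scores candidates with the decoder conditioned on image features):

```python
import math

def beam_search(step_logprobs, start, eos, beam_size=3, max_len=10):
    """Keep the beam_size highest-scoring partial sequences at each step.

    step_logprobs(seq) returns {next_token: log_probability}; in the real
    script this would come from the decoder."""
    beams = [([start], 0.0)]
    for _ in range(max_len):
        candidates, all_done = [], True
        for seq, score in beams:
            if seq[-1] == eos:                 # finished sequences carry over unchanged
                candidates.append((seq, score))
                continue
            all_done = False
            for tok, lp in step_logprobs(seq).items():
                candidates.append((seq + [tok], score + lp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all_done:
            break
    return max(beams, key=lambda c: c[1])[0]   # best-scoring complete sequence

# Toy "decoder": a fixed table of next-token distributions
table = {
    ("<s>",): {"a": math.log(0.6), "the": math.log(0.4)},
    ("<s>", "a"): {"dog": math.log(0.9), "</s>": math.log(0.1)},
    ("<s>", "the"): {"</s>": math.log(1.0)},
    ("<s>", "a", "dog"): {"</s>": math.log(1.0)},
}
caption = beam_search(lambda seq: table[tuple(seq)], "<s>", "</s>", beam_size=2)
print(" ".join(caption))  # <s> a dog </s>
```

With `beam_size=1` this degenerates to greedy decoding; larger beams keep alternative prefixes alive, which is what lets the search recover when the locally most likely word leads to a dead end.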

Detailed Workflow¶

  1. Model Loading: The script can load models either from a specified path or from Weights & Biases (wandb).

  2. Tokenization: Depending on the configuration, the script uses either the BERT tokenizer or a custom word dictionary for tokenizing the captions.

  3. Image Preprocessing: The input image is loaded and preprocessed to match the input format expected by the model, including resizing and normalization.

  4. Caption Generation and Visualization:

    • The preprocessed image is passed through the encoder to obtain image features.
    • The decoder then uses these features to generate a caption using the beam search algorithm.
    • For each word in the generated caption, the attention weights (alpha values) are visualized. This visualization is achieved by overlaying a heatmap on the original image, indicating the regions of the image the model focused on when generating each word.
  5. Command-Line Interface (CLI): The script is designed to be run from the command line, with arguments for specifying the image path and model details. It can also be used in a Jupyter Notebook, as demonstrated in the next section.
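Step 3 (image preprocessing) can be sketched with plain NumPy. The mean/std values below are the standard ImageNet statistics used by torchvision's pretrained VGG19; I assume the script follows the usual torchvision preprocessing, and the resize step is omitted here:

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])  # per-channel RGB mean
IMAGENET_STD = np.array([0.229, 0.224, 0.225])   # per-channel RGB std

def preprocess(img):
    """img: HxWx3 uint8 RGB array -> 3xHxW float32, normalized."""
    x = img.astype(np.float32) / 255.0               # scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD           # per-channel normalization
    return x.transpose(2, 0, 1).astype(np.float32)   # HWC -> CHW, as the encoder expects

tensor = preprocess(np.full((224, 224, 3), 255, dtype=np.uint8))
print(tensor.shape)  # (3, 224, 224)
```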

Code¶

The script is available at generate_caption.py.

Qualitative Evaluation¶

I'll use the following table to qualitatively judge the model's performance on the test set.

| Criteria | 1 Point | 2 Points | 3 Points | 4 Points | 5 Points |
| --- | --- | --- | --- | --- | --- |
| Adequacy | Caption does not relate to the image content. | Caption vaguely relates to the image. | Caption covers basic elements in the image. | Caption covers most elements in the image. | Caption accurately covers all key elements of the image. |
| Semantic Correctness | Caption makes no logical sense. | Caption has significant logical flaws. | Caption is somewhat logical but has some inaccuracies. | Caption is logical with minor inaccuracies. | Caption is completely logical and accurate. |
| Object Detection | Fails to identify any objects. | Identifies less than 50% of objects correctly. | Identifies about 50% of objects correctly. | Identifies most objects correctly. | Identifies all objects correctly. |
| Color Detection | Fails to identify any colors. | Identifies less than 50% of colors correctly. | Identifies about 50% of colors correctly. | Identifies most colors correctly. | Identifies all colors correctly. |
| Attention Visualization | No attention visualization. | Attention visualization is not meaningful. | Attention visualization is somewhat meaningful. | Attention visualization is mostly meaningful. | Attention visualization is very meaningful. |
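The TOTAL and RATING rows in the per-model tables below are simply the sum and the mean of the criterion scores. For instance, with the scores assigned to plain-att-173:

```python
def rate(scores):
    """scores: criterion -> points (1-5). Returns (total, mean rounded to 1 dp)."""
    total = sum(scores.values())
    return total, round(total / len(scores), 1)

total, avg = rate({
    "Adequacy": 4,
    "Semantic Correctness": 2,
    "Object Detection": 3,
    "Color Detection": 4,
    "Attention Visualization": 4,
})
print(total, avg)  # 17 3.4
```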
In [1]:
from generate_caption import generate_caption_visualization, load_model

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'
In [2]:
WANDB_PROJECT = "yvokeller/show-attend-and-tell/"
WANDB_MODEL_FOLDER = "model/"
WANDB_MODEL_NAME_TEMPLATE = "model_vgg19_X.pth"

def load_model_from_checkpoint(run_id, checkpoint):
    wandb_run = WANDB_PROJECT + run_id
    wandb_model = WANDB_MODEL_FOLDER + WANDB_MODEL_NAME_TEMPLATE.replace("X", str(checkpoint))
    
    return load_model(wandb_run=wandb_run, wandb_model=wandb_model)

def caption_images(img_paths, run_id, checkpoint, figsize=(9, 6), beam_size=3):
    encoder, decoder, bert, model_path, model_config_path = load_model_from_checkpoint(run_id, checkpoint)
    for img_path in img_paths:
        generate_caption_visualization(img_path, encoder, decoder, model_path, model_config_path, beam_size=beam_size, figsize=figsize)
In [3]:
# Image Test Sets
own_images = ['data/mine/train.jpeg', 'data/mine/tashi.jpeg', 'data/mine/lake.jpeg']
test_images_flickr8k = [
    'data/flickr8k/imgs/667626_18933d713e.jpg',
    'data/flickr8k/imgs/280706862_14c30d734a.jpg',
    'data/flickr8k/imgs/3072172967_630e9c69d0.jpg',
    'data/flickr8k/imgs/2654514044_a70a6e2c21.jpg',
    'data/flickr8k/imgs/311146855_0b65fdb169.jpg',
    'data/flickr8k/imgs/2218609886_892dcd6915.jpg',
    'data/flickr8k/imgs/2511019188_ca71775f2d.jpg',
    'data/flickr8k/imgs/2435685480_a79d42e564.jpg',
    'data/flickr8k/imgs/3482062809_3b694322c4.jpg'
]

Model plain-att-173¶

https://wandb.ai/yvokeller/show-attend-and-tell/runs/8nu0sdou

Remarks

  • Is able to identify animals and persons well (dog, gender of people)
  • Basketball correctly identified

Qualitative Evaluation

Criteria Rating (1-5)
Adequacy 4
Semantic Correctness 2
Object Detection 3
Color Detection 4
Attention Visualization 4

TOTAL 17

RATING 3.4

In [4]:
caption_images(test_images_flickr8k, '8nu0sdou', 8, beam_size=3)

Model plain-noatt-175¶

https://wandb.ai/yvokeller/show-attend-and-tell/runs/2asua9lu

Remarks

  • High-level descriptions are somewhat related to the image
  • Basketball wrongly identified as a soccer ball
  • A young girl identified instead of street artists

Qualitative Evaluation

Criteria Rating (1-5)
Adequacy 3
Semantic Correctness 2
Object Detection 2
Color Detection 3
Attention Visualization N/A

TOTAL 10

RATING 2.0

In [12]:
caption_images(test_images_flickr8k, '2asua9lu', 8, beam_size=3)

Model bert-att-176¶

https://wandb.ai/yvokeller/show-attend-and-tell/runs/w97u6fc9

Remarks

  • Sentences mostly make no sense (are not semantically correct)
  • Things like girl, basketball and dog identified, but not used in a correct sentence (runs through the grass in a grass)
  • Tends to repeat itself talking about the same thing again and again (is is standing in the water of a water)
  • Attention seems suboptimal

Qualitative Evaluation

Criteria Rating (1-5)
Adequacy 2
Semantic Correctness 1
Object Detection 2
Color Detection 2
Attention Visualization 2

TOTAL 9

RATING 1.8

In [4]:
caption_images(test_images_flickr8k, 'w97u6fc9', 8, beam_size=3)

Model bert-noatt-177¶

https://wandb.ai/yvokeller/show-attend-and-tell/runs/aatq1i5p

Remarks

  • Sentences also mostly make no sense (a group of a white shirt)
  • Here the lack of attention is more obvious; the model jumps between different objects/topics

Qualitative Evaluation

Criteria Rating (1-5)
Adequacy 1
Semantic Correctness 1
Object Detection 1
Color Detection 3
Attention Visualization N/A

TOTAL 6

RATING 1.5

In [5]:
caption_images(test_images_flickr8k, 'aatq1i5p', 8, beam_size=3)

Model plain-lr-0.001-180¶

And finally, the best-performing model as judged by BLEU scores, trained with a higher learning rate.

https://wandb.ai/yvokeller/show-attend-and-tell/runs/tuzy19bt

Remarks

  • Features creative sentences that make sense
  • Identifies objects well (girl, dog, basketball, people), although not without mistakes (car instead of train)
  • Has an eye for detail (a brown dog in the green watches; it's a pot, but nice try)
  • Two dogs (brown and black) correctly identified and described
  • Snow detected instead of beach (but the image is overexposed, so it's hard to tell)
  • Correctly describes a complex scene with two people on a bench and another man in a blue shirt (it's actually pants) watching!

Qualitative Evaluation

Criteria Rating (1-5)
Adequacy 4
Semantic Correctness 4
Object Detection 4
Color Detection 5
Attention Visualization 4

TOTAL 21

RATING 4.2

In [37]:
caption_images(test_images_flickr8k, 'tuzy19bt', 8, beam_size=3)

And for a final experiment, I will test the model on random pictures I shot.

In [26]:
caption_images(own_images, 'tuzy19bt', 8, beam_size=9)

Conclusion¶

Model plain-lr-0.001-180 performs best in terms of both BLEU scores and the qualitative evaluation, making it the best of the trained models. It identifies objects and colors well and generates captions that are semantically correct and adequate. The attention visualization is also meaningful, focusing on relevant parts of the image when generating captions. I like the creative sentences it generates, and it describes complex scenes with multiple people well.

Model bert-noatt-177 performs worst in terms of both BLEU scores and the qualitative evaluation, making it the worst of the trained models.